// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//


#include <pollux/exec/merge.h>
#include <pollux/common/testutil/test_value.h>
#include <pollux/exec/task.h>

using kumo::pollux::common::testutil::TestValue;

namespace kumo::pollux::exec {
    namespace {
        // Builds the serde options used to deserialize shuffled data. The
        // Presto serde gets its specialized options type; every other serde
        // kind uses the generic options. In both cases the compression kind
        // comes from the query config's shuffle compression setting.
        std::unique_ptr<VectorSerde::Options> getVectorSerdeOptions(
            const core::QueryConfig &queryConfig,
            VectorSerde::Kind kind) {
            std::unique_ptr<VectorSerde::Options> serdeOptions;
            if (kind == VectorSerde::Kind::kPresto) {
                serdeOptions = std::make_unique<
                    serializer::presto::PrestoVectorSerde::PrestoOptions>();
            } else {
                serdeOptions = std::make_unique<VectorSerde::Options>();
            }
            serdeOptions->compressionKind =
                    common::stringToCompressionKind(queryConfig.shuffleCompressionKind());
            return serdeOptions;
        }
    } // namespace

    // Base operator for sorted merges. Translates the sorting key expressions
    // into (channel, CompareFlags) pairs that the merge streams use to compare
    // rows.
    //
    // @param outputType Output row type; moved into the SourceOperator base
    //        (the member outputType_ is used afterwards).
    // @param sortingKeys Field accesses identifying the sort columns. Constant
    //        keys are rejected: a constant cannot define a merge order.
    // @param sortingOrders Per-key ascending/descending and nulls placement;
    //        must be parallel to 'sortingKeys'.
    Merge::Merge(
        int32_t operatorId,
        DriverCtx *driverCtx,
        RowTypePtr outputType,
        const std::vector<std::shared_ptr<const core::FieldAccessTypedExpr> > &
        sortingKeys,
        const std::vector<core::SortOrder> &sortingOrders,
        const std::string &planNodeId,
        const std::string &operatorType)
        : SourceOperator(
              driverCtx,
              std::move(outputType),
              operatorId,
              planNodeId,
              operatorType),
          outputBatchSize_{outputBatchRows()} {
        const auto numKeys = sortingKeys.size();
        sortingKeys_.reserve(numKeys);
        // Use size_t to avoid a signed/unsigned comparison against numKeys.
        for (size_t i = 0; i < numKeys; ++i) {
            auto channel = exprToChannel(sortingKeys[i].get(), outputType_);
            POLLUX_CHECK_NE(
                channel,
                kConstantChannel,
                "Merge doesn't allow constant grouping keys");
            sortingKeys_.emplace_back(
                channel,
                CompareFlags{
                    sortingOrders[i].isNullsFirst(),
                    sortingOrders[i].isAscending(),
                    false
                });
        }
    }

    // Wraps every merge source in a SourceStream cursor, records a raw
    // (non-owning) pointer to each cursor in streams_, and hands ownership of
    // all cursors to the TreeOfLosers that picks the next row across sources.
    void Merge::initializeTreeOfLosers() {
        const auto numSources = sources_.size();
        std::vector<std::unique_ptr<SourceStream> > sourceCursors;
        sourceCursors.reserve(numSources);
        streams_.reserve(numSources);
        for (auto &mergeSource: sources_) {
            auto cursor = std::make_unique<SourceStream>(
                mergeSource.get(), sortingKeys_, outputBatchSize_);
            // Keep the raw pointer before ownership moves into the tree.
            streams_.push_back(cursor.get());
            sourceCursors.push_back(std::move(cursor));
        }
        treeOfLosers_ =
                std::make_unique<TreeOfLosers<SourceStream> >(std::move(sourceCursors));
    }

    // Reports whether the operator must wait before producing output.
    // Sequencing: (1) let the subclass add merge sources; (2) handle early
    // task termination (no sources => finished); (3) lazily build the tree of
    // losers when merging more than one source; (4) collect blocking futures
    // from all streams and hand them back to the driver one per call.
    BlockingReason Merge::isBlocked(ContinueFuture *future) {
        TestValue::adjust("kumo::pollux::exec::Merge::isBlocked", this);

        auto reason = addMergeSources(future);
        if (reason != BlockingReason::kNotBlocked) {
            return reason;
        }

        // NOTE: the task might terminate early which leaves empty sources. Once it
        // happens, we shall simply mark the merge operator as finished.
        if (sources_.empty()) {
            finished_ = true;
            return BlockingReason::kNotBlocked;
        }

        // No merging is needed if there is only one source.
        if (streams_.empty() && sources_.size() > 1) {
            initializeTreeOfLosers();
        }

        // Poll the streams only when no futures are already pending; otherwise
        // keep draining the backlog collected on an earlier call.
        if (sourceBlockingFutures_.empty()) {
            for (auto &cursor: streams_) {
                cursor->isBlocked(sourceBlockingFutures_);
            }
        }

        // Surface one pending future per call; the driver re-invokes
        // isBlocked() until the backlog is empty.
        if (!sourceBlockingFutures_.empty()) {
            *future = std::move(sourceBlockingFutures_.back());
            sourceBlockingFutures_.pop_back();
            return BlockingReason::kWaitForProducer;
        }

        return BlockingReason::kNotBlocked;
    }

    // True once merging is complete: set by isBlocked() when the task ends
    // with no sources, and by getOutput() when all sources are exhausted.
    bool Merge::isFinished() {
        return finished_;
    }

    // Produces the next batch of merged output. With a single source, batches
    // are passed through unchanged. Otherwise rows are drawn from the tree of
    // losers until the output batch fills up, a source blocks, or all sources
    // are exhausted. Returns nullptr when blocked or when there is no data.
    RowVectorPtr Merge::getOutput() {
        if (finished_) {
            return nullptr;
        }

        // No merging is needed if there is only one source.
        if (sources_.size() == 1) {
            ContinueFuture future;
            RowVectorPtr data;
            auto reason = sources_[0]->next(data, &future);
            if (reason != BlockingReason::kNotBlocked) {
                // Park the future; isBlocked() will hand it to the driver.
                sourceBlockingFutures_.emplace_back(std::move(future));
                return nullptr;
            }

            // A null batch from the source signals end of data.
            finished_ = data == nullptr;
            return data;
        }

        // Lazily allocate the reusable output batch, pre-sized to
        // outputBatchSize_ rows.
        if (!output_) {
            output_ = BaseVector::create<RowVector>(
                outputType_, outputBatchSize_, operatorCtx_->pool());
            for (auto &child: output_->children()) {
                child->resize(outputBatchSize_);
            }
        }

        for (;;) {
            auto stream = treeOfLosers_->next();

            // A null stream means every source is exhausted.
            if (!stream) {
                finished_ = true;

                // Return nullptr if there is no data.
                if (outputSize_ == 0) {
                    return nullptr;
                }

                output_->resize(outputSize_);
                return std::move(output_);
            }

            if (stream->setOutputRow(outputSize_)) {
                // The stream is at end of input batch. Need to copy out the rows before
                // fetching next batch in 'pop'.
                stream->copyToOutput(output_);
            }

            ++outputSize_;

            // Advance the stream.
            stream->pop(sourceBlockingFutures_);

            if (outputSize_ == outputBatchSize_) {
                // Copy out data from all sources.
                for (auto &s: streams_) {
                    s->copyToOutput(output_);
                }

                outputSize_ = 0;
                return std::move(output_);
            }

            // A source blocked while advancing; return and let isBlocked()
            // hand the pending futures to the driver.
            if (!sourceBlockingFutures_.empty()) {
                return nullptr;
            }
        }
    }

    void Merge::close() {
        for (auto &source: sources_) {
            source->close();
        }
    }

    // Strict weak ordering between the current rows of two streams, comparing
    // the sorting key columns in key order. Returns true if this stream's
    // current row sorts before 'other''s; false when all keys compare equal.
    bool SourceStream::operator<(const MergeStream &other) const {
        const auto &otherCursor = static_cast<const SourceStream &>(other);
        // size_t index avoids a signed/unsigned comparison with size().
        for (size_t i = 0; i < sortingKeys_.size(); ++i) {
            const auto &[_, compareFlags] = sortingKeys_[i];
            POLLUX_DCHECK(
                compareFlags.nullAsValue(), "not supported null handling mode");
            // compare() returns an optional; with nulls treated as values a
            // result is always produced, so .value() is safe here.
            if (auto result = keyColumns_[i]
                    ->compare(
                        otherCursor.keyColumns_[i],
                        currentSourceRow_,
                        otherCursor.currentSourceRow_,
                        compareFlags)
                    .value()) {
                return result < 0;
            }
        }
        return false;
    }

    // Advances to the next row of the current batch. When the batch is
    // exhausted, fetches the next one. Returns true if that fetch blocked, in
    // which case a future has been appended to 'futures'.
    bool SourceStream::pop(std::vector<ContinueFuture> &futures) {
        ++currentSourceRow_;
        if (currentSourceRow_ != data_->size()) {
            return false;
        }
        // Every row of the finished batch must have been copied out already.
        POLLUX_CHECK(!outputRows_.hasSelections());
        return fetchMoreData(futures);
    }

    // Copies this stream's pending rows (marked in outputRows_) from the
    // current source batch into 'output'. Pending rows are consecutive in the
    // source batch, starting at firstSourceRow_.
    void SourceStream::copyToOutput(RowVectorPtr &output) {
        outputRows_.updateBounds();

        if (!outputRows_.hasSelections()) {
            return;
        }

        // Map each selected output row to the next consecutive source row.
        vector_size_t sourceRow = firstSourceRow_;
        outputRows_.applyToSelected(
            [&](auto row) { sourceRows_[row] = sourceRow++; });

        for (auto i = 0; i < output->type()->size(); ++i) {
            output->childAt(i)->copy(
                data_->childAt(i).get(), outputRows_, sourceRows_.data());
        }

        outputRows_.clearAll();

        // If the whole batch was consumed, the next batch starts at row 0;
        // otherwise remember where the next copy should resume.
        if (sourceRow == data_->size()) {
            firstSourceRow_ = 0;
        } else {
            firstSourceRow_ = sourceRow;
        }
    }

    // Pulls the next batch from the source. Returns true and appends a future
    // to 'futures' when the source is not ready yet. On success, resets the
    // cursor, loads any lazy children, and re-resolves the key columns; a null
    // or empty batch marks the stream as at end.
    bool SourceStream::fetchMoreData(std::vector<ContinueFuture> &futures) {
        ContinueFuture future;
        const auto blockingReason = source_->next(data_, &future);
        if (blockingReason != BlockingReason::kNotBlocked) {
            needData_ = true;
            futures.emplace_back(std::move(future));
            return true;
        }

        needData_ = false;
        currentSourceRow_ = 0;
        atEnd_ = !data_ || data_->size() == 0;
        if (atEnd_) {
            return false;
        }

        // Force-load lazy children so key comparisons can access them directly.
        for (auto &child: data_->children()) {
            child = BaseVector::loaded_vector_shared(child);
        }
        // Re-resolve raw pointers to the key columns of the new batch.
        keyColumns_.clear();
        for (const auto &key: sortingKeys_) {
            keyColumns_.push_back(data_->childAt(key.first).get());
        }
        return false;
    }

    // LocalMerge merges the sorted partitions produced locally by the task.
    // It consumes all local merge sources by itself, so it must be the only
    // driver of its pipeline (driver id 0).
    LocalMerge::LocalMerge(
        int32_t operatorId,
        DriverCtx *driverCtx,
        const std::shared_ptr<const core::LocalMergeNode> &localMergeNode)
        : Merge(
            operatorId,
            driverCtx,
            localMergeNode->outputType(),
            localMergeNode->sortingKeys(),
            localMergeNode->sortingOrders(),
            localMergeNode->id(),
            "LocalMerge") {
        const auto driverId = operatorCtx_->driverCtx()->driverId;
        POLLUX_CHECK_EQ(
            driverId, 0, "LocalMerge needs to run single-threaded");
    }

    // Lazily grabs the task's local merge sources on the first call. Local
    // sources are available immediately, so this never blocks.
    BlockingReason LocalMerge::addMergeSources(ContinueFuture * /* future */) {
        if (!sources_.empty()) {
            return BlockingReason::kNotBlocked;
        }
        sources_ = operatorCtx_->task()->getLocalMergeSources(
            operatorCtx_->driverCtx()->splitGroupId, planNodeId());
        return BlockingReason::kNotBlocked;
    }

    // MergeExchange merges sorted streams arriving from remote tasks. Caches
    // the serde matching the plan node's serde kind and builds serde options
    // (including the shuffle compression kind) from the query config.
    MergeExchange::MergeExchange(
        int32_t operatorId,
        DriverCtx *driverCtx,
        const std::shared_ptr<const core::MergeExchangeNode> &mergeExchangeNode)
        : Merge(
              operatorId,
              driverCtx,
              mergeExchangeNode->outputType(),
              mergeExchangeNode->sortingKeys(),
              mergeExchangeNode->sortingOrders(),
              mergeExchangeNode->id(),
              "MergeExchange"),
          serde_(getNamedVectorSerde(mergeExchangeNode->serdeKind())),
          serdeOptions_(getVectorSerdeOptions(
              driverCtx->queryConfig(),
              mergeExchangeNode->serdeKind())) {
    }

    // Collects remote splits from the task and, once the last split has been
    // seen, creates one merge-exchange source per remote task. Blocks while
    // waiting for more splits to arrive. Only the pipeline-0 operator does
    // this work.
    BlockingReason MergeExchange::addMergeSources(ContinueFuture *future) {
        if (operatorCtx_->driverCtx()->driverId != 0) {
            // When there are multiple pipelines, a single operator, the one from
            // pipeline 0, is responsible for merging pages.
            return BlockingReason::kNotBlocked;
        }
        if (noMoreSplits_) {
            return BlockingReason::kNotBlocked;
        }
        // Drain splits until either the task blocks us or reports no more splits.
        for (;;) {
            exec::Split split;
            auto reason = operatorCtx_->task()->getSplitOrFuture(
                operatorCtx_->driverCtx()->splitGroupId, planNodeId(), split, *future);
            if (reason == BlockingReason::kNotBlocked) {
                if (split.hasConnectorSplit()) {
                    auto remoteSplit = std::dynamic_pointer_cast<RemoteConnectorSplit>(
                        split.connectorSplit);
                    POLLUX_CHECK(remoteSplit, "Wrong type of split");
                    remoteSourceTaskIds_.push_back(remoteSplit->taskId);
                } else {
                    // A split without a connector split signals end of splits:
                    // create the sources now that all remote tasks are known.
                    noMoreSplits_ = true;
                    if (!remoteSourceTaskIds_.empty()) {
                        // Split the exchange buffer budget evenly across sources,
                        // clamped to the per-source lower/upper limits.
                        const auto maxMergeExchangeBufferSize =
                                operatorCtx_->driverCtx()
                                ->queryConfig()
                                .maxMergeExchangeBufferSize();
                        const auto maxQueuedBytesPerSource = std::min<int64_t>(
                            std::max<int64_t>(
                                maxMergeExchangeBufferSize / remoteSourceTaskIds_.size(),
                                MergeSource::kMaxQueuedBytesLowerLimit),
                            MergeSource::kMaxQueuedBytesUpperLimit);
                        for (uint32_t remoteSourceIndex = 0;
                             remoteSourceIndex < remoteSourceTaskIds_.size();
                             ++remoteSourceIndex) {
                            // Each source gets its own memory pool from the task.
                            auto *pool = operatorCtx_->task()->addMergeSourcePool(
                                operatorCtx_->planNodeId(),
                                operatorCtx_->driverCtx()->pipelineId,
                                remoteSourceIndex);
                            sources_.emplace_back(MergeSource::createMergeExchangeSource(
                                this,
                                remoteSourceTaskIds_[remoteSourceIndex],
                                operatorCtx_->task()->destination(),
                                maxQueuedBytesPerSource,
                                pool,
                                operatorCtx_->task()->queryCtx()->executor()));
                        }
                    }
                    // TODO Delay this call until all input data has been processed.
                    operatorCtx_->task()->multipleSplitsFinished(
                        false, remoteSourceTaskIds_.size(), 0);
                    return BlockingReason::kNotBlocked;
                }
            } else {
                return reason;
            }
        }
    }

    // Closes all remote sources, runs the base-class close, then records the
    // shuffle serde and compression kinds as runtime stats.
    void MergeExchange::close() {
        for (auto &source: sources_) {
            source->close();
        }
        Operator::close();
        // Scope the stats lock tightly; the brace was previously fused onto
        // the Operator::close() line, hiding this block's structure.
        {
            auto lockedStats = stats_.wlock();
            lockedStats->addRuntimeStat(
                Operator::kShuffleSerdeKind,
                RuntimeCounter(static_cast<int64_t>(serde_->kind())));
            lockedStats->addRuntimeStat(
                Operator::kShuffleCompressionKind,
                RuntimeCounter(static_cast<int64_t>(serdeOptions_->compressionKind)));
        }
    }
} // namespace kumo::pollux::exec
